close all
clear

% ---- Data set and scan configuration ------------------------------------
folderName = '2023_01_30_1_effective_U_modfreq3_1070_CK_DC';
prefix = 'scan';               % shot files are matched as [prefix '*' <index> 'atomMatrix.mat']
paramVar ='stat_phase_pi';     % scanned parameter; its values label every plot

% Separate analysis for each evolution_dur:
paramVar2 = 'evolution_dur';
evolution_dur = 36.25;         % only batch rows with exactly this value are analysed
plot_figure = 1;
export_data = 1;
stat_phase_pi = 0.5;           % NOTE(review): not referenced anywhere else in this script

% ---- Region of interest inside atomMatrix -------------------------------
col1 = 115;
col2 = 127;

row1 = 90;  % the number of rows (row2-row1+1) must be even
row2 = 113; % starting site must be in the center
rlen = row2-row1+1;
if mod(rlen, 2) ~= 0
    error('analysis:oddRowCount', ...
        'The row window row1:row2 must contain an even number of rows (got %d).', rlen);
end
% The two rows (1-based, relative to row1) adjacent to the middle of the window.
centersites = [rlen/2 rlen/2+1];

% ---- Batch parameters ---------------------------------------------------
[param_names, param_table] = get_batch_params(folderName); %param_table has extra zero at the end for some reason
paramInd = find(strcmp(param_names, paramVar));
if numel(paramInd) ~= 1
    error('analysis:paramLookup', ...
        'Expected exactly one batch parameter named "%s", found %d.', paramVar, numel(paramInd));
end
allfvalues = param_table(:, paramInd);

% Restrict to the lines that have the good evolution dur:
param2Ind = find(strcmp(param_names, paramVar2));
if numel(param2Ind) ~= 1
    error('analysis:paramLookup', ...
        'Expected exactly one batch parameter named "%s", found %d.', paramVar2, numel(param2Ind));
end
alltvalues = param_table(:, param2Ind);
good_index = alltvalues == evolution_dur;
if ~any(good_index)
    error('analysis:noMatchingRows', ...
        'No batch row has %s == %g.', paramVar2, evolution_dur);
end
allfvalues = allfvalues(good_index);
allflength = length(allfvalues);

% Placeholder with the name used for the per-shot occupation matrix;
% presumably replaced by the variable stored in each shot file - verify.
atomMatrix = zeros(248,248);
uniquefvals = unique(allfvalues);
flen = numel(uniquefvals);

% Distinct doublon_ramp_dur settings. Note that these are collected over ALL
% rows of param_table, not only over the rows selected by good_index.
doublonInd = find(strcmp(param_names, 'doublon_ramp_dur'));
if numel(doublonInd) ~= 1
    error('analysis:paramLookup', ...
        'Expected exactly one batch parameter named "doublon_ramp_dur", found %d.', numel(doublonInd));
end
doublon_ramp_dur_all = param_table(:, doublonInd);
doublon_ramp_dur_list = unique(doublon_ramp_dur_all);


%% Extract data
%
% For every scan point and every shot, each column col1:col2 of the region
% of interest is inspected. Only columns whose rows row1:row2 sum to exactly
% 2 are kept (post-selection). Each kept column is sorted into one of three
% cases and accumulated into occupation histograms (bo), pair-position
% histograms (corr_*) and left/centre/right counters (P_left_right).
%
% NOTE(review): the row(1)/row(2) indexing below assumes atomMatrix entries
% are 0 or 1, so that a column sum of 2 means two distinct occupied rows.

% Number of scan points per doublon_ramp_dur setting.
flen_aux = allflength / numel(doublon_ramp_dur_list);

% case 1: difference between atom positions = 0 (doublon)
% case 2: difference between atom positions = 1
% case 3: difference between atom positions > 1
Npost = zeros(flen_aux, 3, 2);              % kept columns: (scan point, case, dd)
Nrealizations = zeros(flen_aux, 2);         % inspected columns: (scan point, dd)
bo = zeros(flen_aux, rlen, 3, 2);           % summed occupations: (scan point, row, case, dd)
bo_density = zeros(flen_aux, rlen, 3, 2);   % bo divided by Npost
% Pair-position histograms (row, row, scan point). Suffix _a is filled in
% runs with doublon_ramp_dur <= 0, suffix _b in runs with doublon_ramp_dur > 0.
corr_diff1_a = zeros(rlen, rlen, flen_aux);
corr_diff2_a = zeros(rlen, rlen, flen_aux);
corr_diff0_b = zeros(rlen, rlen, flen_aux);
corr_diff2_b = zeros(rlen, rlen, flen_aux);
corr_diff1_a_density = zeros(rlen, rlen, flen_aux);
corr_diff2_a_density = zeros(rlen, rlen, flen_aux);
corr_diff0_b_density = zeros(rlen, rlen, flen_aux);
corr_diff2_b_density = zeros(rlen, rlen, flen_aux);

% Last index: 3 = both atoms at rows <= centersites(1),
%             1 = both atoms at rows >= centersites(2),
%             2 = one atom on either side.
P_left_right = zeros(flen_aux, 3, 2, 3);

% Find where the part with the good evolution_dur starts (same for every dd).
i_list = find(good_index);
i_0 = i_list(1);

for dd = 1:numel(doublon_ramp_dur_list)
    doublon_ramp_dur = doublon_ramp_dur_list(dd);
    isDoublonRun = doublon_ramp_dur > 0;

    % File indices belonging to this doublon_ramp_dur setting; the dd-th
    % setting is taken to occupy the dd-th consecutive chunk of flen_aux files.
    start = i_0 + flen_aux*(dd-1);
    stop = i_0 - 1 + flen_aux*dd;

    for jj = 1:(stop-start+1)
        listing = dir(fullfile(folderName, [prefix '*' num2str(start+jj-1,'%03.f') 'atomMatrix.mat']));
        Nshots = size(listing,1);
        Nrealizations(jj, dd) = (col2-col1+1) * Nshots;

        for shotInd = 1:Nshots
            % Load only the occupation matrix, and into a struct, so that a
            % shot file can neither overwrite other workspace variables nor
            % silently leave a stale atomMatrix in place.
            shot = load(fullfile(folderName, listing(shotInd).name), 'atomMatrix');
            atomMatrix = shot.atomMatrix;

            for col = col1:col2
                occupation = atomMatrix(row1:row2, col);
                if sum(occupation) ~= 2
                    continue  % post-selection: exactly two atoms in the column
                end
                row = find(occupation);  % occupied rows, ascending, relative to row1

                % Decide which of the three cases this column belongs to.
                if diff(row) > 1
                    caseInd = 3;
                elseif isDoublonRun
                    caseInd = 1;
                else
                    caseInd = 2;
                end
                Npost(jj,caseInd,dd) = Npost(jj,caseInd,dd) + 1;

                if caseInd == 1
                    % Adjacent pair in a doublon run: booked as a double
                    % occupancy of the second row (original note: "quic").
                    doublon_pos = row(2);
                    atomRow = zeros(1, rlen);
                    atomRow(doublon_pos) = 2;
                    bo(jj,:,1,dd) = bo(jj,:,1,dd) + atomRow;
                    corr_diff0_b(doublon_pos, doublon_pos, jj) = corr_diff0_b(doublon_pos, doublon_pos, jj) + 2;
                    % row(1) < centersites(1) and row(1) >= centersites(1)
                    % are complementary, so a doublon is never counted as
                    % "one atom on either side".
                    if row(1) < centersites(1)
                        side = 3;
                    else
                        side = 1;
                    end
                else
                    bo(jj,:,caseInd,dd) = bo(jj,:,caseInd,dd) + occupation';
                    % Symmetric pair histogram: count both (r1,r2) and (r2,r1).
                    if caseInd == 3 && isDoublonRun
                        corr_diff2_b(row(1), row(2), jj) = corr_diff2_b(row(1), row(2), jj) + 1;
                        corr_diff2_b(row(2), row(1), jj) = corr_diff2_b(row(2), row(1), jj) + 1;
                    elseif caseInd == 3
                        corr_diff2_a(row(1), row(2), jj) = corr_diff2_a(row(1), row(2), jj) + 1;
                        corr_diff2_a(row(2), row(1), jj) = corr_diff2_a(row(2), row(1), jj) + 1;
                    else
                        corr_diff1_a(row(1), row(2), jj) = corr_diff1_a(row(1), row(2), jj) + 1;
                        corr_diff1_a(row(2), row(1), jj) = corr_diff1_a(row(2), row(1), jj) + 1;
                    end
                    if row(2) <= centersites(1)
                        side = 3;
                    elseif row(1) >= centersites(2)
                        side = 1;
                    else
                        side = 2;
                    end
                end
                P_left_right(jj,caseInd,dd,side) = P_left_right(jj,caseInd,dd,side) + 1;
            end
        end
    end

    % Turn the counts of this dd into per-event averages (0/0 gives NaN,
    % which is replaced by 0 further down).
    for nn = 1:3
        bo_density(:,:,nn,dd) = bo(:,:,nn,dd) ./ repmat(Npost(:,nn,dd), 1, size(bo,2));
        P_left_right(:,nn,dd,:) = P_left_right(:,nn,dd,:) ./ repmat(Npost(:,nn,dd), 1, 1, 1, 3);
    end
end

% Normalise the pair histograms once, after all shots have been accumulated.
% The _a histograms use the counts of dd = 1, the _b histograms those of dd = 2.
corr_diff1_a_density = corr_diff1_a ./ reshape(Npost(:,2,1), 1, 1, []);
corr_diff2_a_density = corr_diff2_a ./ reshape(Npost(:,3,1), 1, 1, []);
corr_diff0_b_density = corr_diff0_b ./ reshape(Npost(:,1,2), 1, 1, []);
corr_diff2_b_density = corr_diff2_b ./ reshape(Npost(:,3,2), 1, 1, []);

% Scan points without any post-selected event produce 0/0; report them as 0.
bo_density(isnan(bo_density)) = 0;
corr_diff1_a_density(isnan(corr_diff1_a_density)) = 0;
corr_diff2_a_density(isnan(corr_diff2_a_density)) = 0;
corr_diff0_b_density(isnan(corr_diff0_b_density)) = 0;
corr_diff2_b_density(isnan(corr_diff2_b_density)) = 0;
Nrealizations_aux = sum(Nrealizations, 2);  % inspected columns summed over dd
P_left_right(isnan(P_left_right)) = 0;


%% Combine data
%
% Merge the four (case, run) data sets into one weighted average per scan
% point, then average scan points that share the same scanned value.

% Event counts used as weights. Cases 1 and 2 enter with a factor of 2.
Ndiff2_a = Npost(:,3,1);
Ndiff2_b = Npost(:,3,2);
Ndiff1_a = Npost(:,2,1);
Ndiff0_b = Npost(:,1,2);

Ndiff0_eff = 2 * Ndiff0_b;
Ndiff1_eff = 2 * Ndiff1_a;
Npost_total = Ndiff2_a + Ndiff2_b + Ndiff1_eff + Ndiff0_eff;

% Occupation profiles, (scan point) x (row).
bo_diff0_b = bo_density(:,:,1,2);
bo_diff1_a = bo_density(:,:,2,1);
bo_diff2_a = bo_density(:,:,3,1);
bo_diff2_b = bo_density(:,:,3,2);

% Column-vector weights expand along the row dimension.
bo_density_weighted = Ndiff2_a .* bo_diff2_a + Ndiff2_b .* bo_diff2_b ...
    + Ndiff0_eff .* bo_diff0_b + Ndiff1_eff .* bo_diff1_a;
bo_density_weighted = bo_density_weighted ./ Npost_total;

% Pair histograms, (row) x (row) x (scan point): move the weights to the
% third dimension so that they expand over both row dimensions.
along3 = @(v) reshape(v, 1, 1, []);
corr_weighted = along3(Ndiff2_a) .* corr_diff2_a_density + along3(Ndiff2_b) .* corr_diff2_b_density ...
    + along3(Ndiff0_eff) .* corr_diff0_b_density + along3(Ndiff1_eff) .* corr_diff1_a_density;
corr_weighted = corr_weighted ./ along3(Npost_total);

% Left-right atom difference
P_left_right0_b = squeeze(P_left_right(:,1,2,:));
P_left_right1_a = squeeze(P_left_right(:,2,1,:));
P_left_right2_a = squeeze(P_left_right(:,3,1,:));
P_left_right2_b = squeeze(P_left_right(:,3,2,:));
P_left_right_weighted = Ndiff2_a .* P_left_right2_a + Ndiff2_b .* P_left_right2_b ...
    + Ndiff0_eff .* P_left_right0_b + Ndiff1_eff .* P_left_right1_a;
P_left_right_weighted = P_left_right_weighted ./ Npost_total;

% Compress data (if there are several data points per batch with same parameters)
fvalues = allfvalues(1:size(bo_density_weighted,1));

bo_compressed = zeros(flen, size(bo,2));
corr_compressed = zeros(rlen, rlen, flen);
P_left_right_compressed = zeros(flen, 3);
Npost_compressed = zeros(flen, 1);
Nrealizations_compressed = zeros(flen, 1);

for ff = 1:flen
    sameValue = (fvalues == uniquefvals(ff));  % scan points taken at this value

    % Averages for the normalised quantities, sums for the raw counts.
    bo_compressed(ff,:) = mean(bo_density_weighted(sameValue,:), 1);
    corr_compressed(:,:,ff) = mean(corr_weighted(:,:,sameValue), 3);
    P_left_right_compressed(ff,:) = mean(P_left_right_weighted(sameValue,:), 1);
    Npost_compressed(ff) = sum(Npost_total(sameValue));
    Nrealizations_compressed(ff) = sum(Nrealizations_aux(sameValue));
end

% Post-selection rate and its binomial standard error.
P_post_compressed = Npost_compressed ./ Nrealizations_compressed;
P_post_err = sqrt(P_post_compressed .* (1 - P_post_compressed) ./ Nrealizations_compressed);


%% Plot data

% Colour map of the averaged occupation: one image row per scanned value,
% one image column per row of the analysis window.
if plot_figure
    fig1 = figure('Name', 'Two-particle quantum walk');
    colormap parula
    % Label the window rows relative to the centre, with centersites(2) at 0
    % and the first row at the largest positive value. Derived from the
    % window size instead of being hard-coded; for a 24-row window this is
    % [12, -11].
    row_axis = centersites(2) - [1, rlen];
    imagesc(row_axis, uniquefvals, bo_compressed)
    colorbar
    yticks([-0.5 0.5])
    xlabel('Row')
    ylabel('Phase (\pi)')
end


%% Difference right - left atom number

% Estimate 1: from the averaged occupation profile. "Right" is the first
% half of the window (rows 1:center_site), "left" the second half.
% Both results are 1-by-flen row vectors.
center_site = rlen/2;
right_atom_num = sum(bo_compressed(:, 1:center_site), 2).';
left_atom_num = sum(bo_compressed(:, (center_site+1):end), 2).';

diff_atom_num = right_atom_num - left_atom_num;

% Estimate 2: from the left/centre/right probabilities, treated as a
% three-valued variable taking -2, 0 and +2. Results are flen-by-1.
diff_atom_num_new = -2 * P_left_right_compressed(:,1) + 0 * P_left_right_compressed(:,2) + 2 * P_left_right_compressed(:,3);
var_diff_atom_num_new = (-2)^2 * P_left_right_compressed(:,1) + 0^2 * P_left_right_compressed(:,2) + 2^2 * P_left_right_compressed(:,3)  - diff_atom_num_new.^2;
% Rounding can push a vanishing variance slightly below zero, which would
% make sqrt return complex numbers. Clamp those entries; logical indexing
% (rather than max) keeps NaN entries as NaN.
var_diff_atom_num_new(var_diff_atom_num_new < 0) = 0;
std_diff_atom_num_new = sqrt(var_diff_atom_num_new);
% Standard error of the mean over the post-selected events.
err_diff_atom_num_new = std_diff_atom_num_new ./ sqrt(Npost_compressed);


%% Plot data

% Overlay both imbalance estimates against the scanned value: the
% profile-based one as a solid line, the probability-based one dashed with
% error bars.
if plot_figure
    ax_diff = axes(figure());
    hold(ax_diff, 'on')
    plot(ax_diff, uniquefvals, diff_atom_num, '-o')
    errorbar(ax_diff, uniquefvals, diff_atom_num_new, err_diff_atom_num_new, '--o')
    xlabel(ax_diff, 'Phase (\pi)')
    ylabel(ax_diff, '\Delta n')
    hold(ax_diff, 'off')
end


%% Post selection rate

% Fraction of inspected columns that passed the post-selection, shown as
% markers against the scanned value.
if plot_figure
    fig3 = figure();
    ax_post = axes(fig3);
    hold(ax_post, 'on')
    plot(ax_post, uniquefvals, P_post_compressed, 'o', 'Linewidth', 1.5)
    xlabel(ax_post, 'Phase (\pi)')
    ylabel(ax_post, 'Post selection rate')
    hold(ax_post, 'off')
end


%% Export data

% Store the imbalance and post-selection results in
% mat_files/m<folderName>_delta_N_vs_modfreq3_evol_<dur>ms.mat, where the
% decimal point of evolution_dur is written as 'p' (e.g. 36p25).
if export_data
    out_dir = 'mat_files';
    if exist(out_dir, 'dir') ~= 7
        mkdir(out_dir);  % save() does not create missing folders
    end
    dur_tag = replace(num2str(evolution_dur), '.', 'p');
    out_name = ['m', folderName, '_delta_N_vs_modfreq3_evol_', dur_tag, 'ms.mat'];
    % fullfile picks the platform's separator instead of a hard-coded '\'.
    save(fullfile(out_dir, out_name), ...
        'uniquefvals', 'diff_atom_num_new', 'err_diff_atom_num_new',  'P_post_compressed', 'P_post_err');
end







